{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Symbolic Math: Differentiation Program\n",
    "\n",
    "This notebook describes how to write a program that manipulates mathematical expressions, with an emphasis on differentiation.  The program can compute $\\frac{d}{dx} (x^2 + \\sin(x))$ to get the result $2x + \\cos(x)$.  Usually in a programming language, when we say `sin(x)` we are computing with *numbers*, not *symbolic expressions*.  That is, programming languages have built-in functions like `sin` and numbers like 0 so that we can ask for the value of `sin(0)` and get back 0.  But now we want to do something different: treat $\\sin(x)$ as a *symbolic expression* that represents any value of $x$.  That facility is not built into most programming languages, so we need a way to represent and manipulate these symbolic expressions. Python has a large system called `sympy` to do just that, but we will develop a smaller, simpler system from scratch.\n",
    "\n",
    "# Representing Symbolic Expressions as Tuples\n",
    "\n",
    "Here is one simple way to represent symbolic expressions:  We will represent numbers as themselves, because they are already built into Python.  We will represent math variables (such as $x$) and constants (such as $\\pi$) as Python strings (such as `'x'` and `'pi'`). And we will represent compound expressions (such as $x + 1$) as Python tuples, such as `('x', '+', 1)`. \n",
    "\n",
    "We will use the notion of [arity](http://en.wikipedia.org/wiki/Arity) of an expression. An expression can be: \n",
    "\n",
    "* Binary (arity = 2): $(x + y)$ is a binary expression because it has two sub-expressions, $x$ and $y$.\n",
    "* Unary (arity = 1): $-x$ and $\\sin(x)$ are both unary expressions, because they each have one sub-expression,  $x$.\n",
    "* Nullary (arity = 0): $x$ and $3$ are both nullary, or *atomic*, expressions, because they have no sub-expressions.\n",
    "\n",
    "We can define this as a Python function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def arity(expr):\n",
    "    \"The number of sub-expressions in this expression.\"\n",
    "    if isinstance(expr, tuple):\n",
    "        return len(expr) - 1\n",
    "    else:\n",
    "        return 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "exp = ('x', '+', 1)\n",
    "arity(exp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This table summarizes the kinds of expressions and how we will represent them:\n",
    "\n",
    "<table>\n",
    "<tr><th>Math<br>Expression<th>Math Type<th>Arity<th>Python<br>Representation<th>Python<br>Type\n",
    "<tr><td>$3$<td>Integer<td>0<td><tt>3<td><tt>int\n",
    "<tr><td>$3.14$<td>Number<td>0<td><tt>3.14<td><tt>float\n",
    "<tr><td>$x$<td>Variable<td>0<td><tt>'x'<td><tt>str\n",
    "<tr><td>$\\pi$<td>Constant<td>0<td><tt>'pi'<td><tt>str\n",
    "<tr><td>$-x$<td>Unary op<td>1<td><tt>('-', 'x')<td><tt>tuple\n",
    "<tr><td>$\\sin(x)$<td>Unary func<td>1<td><tt>('sin', 'x')<td><tt>tuple\n",
    "<tr><td>$x+1$<td>Binary op<td>2<td><tt>('x', '+', 1)<td><tt>tuple\n",
    "<tr><td>$3x - y/2$<td>Nested<br>Binary op<td>2<td><tt>((3, '*', 'x'), '-', <br>&nbsp;('y', '/', 2))<td><tt>tuple\n",
    "\n",
    "</table>"
   ]
  },
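  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the table above, the sketch below restates `arity` in one line and asserts the arity of each row. It is an illustration only, and is meant to behave the same as the definition given earlier. Note that `arity` counts only the top-level sub-expressions, so the nested example still has arity 2.\n",
    "\n",
    "```python\n",
    "def arity(expr):\n",
    "    'The number of sub-expressions: len - 1 for a tuple, 0 for an atom.'\n",
    "    return len(expr) - 1 if isinstance(expr, tuple) else 0\n",
    "\n",
    "# Atoms: numbers, variables, and constants have no sub-expressions.\n",
    "assert arity(3) == arity(3.14) == arity('x') == arity('pi') == 0\n",
    "\n",
    "# Unary forms are 2-tuples: (op, u).\n",
    "assert arity(('-', 'x')) == 1\n",
    "assert arity(('sin', 'x')) == 1\n",
    "\n",
    "# Binary forms are 3-tuples: (u, op, v), whether or not u and v are nested.\n",
    "assert arity(('x', '+', 1)) == 2\n",
    "assert arity(((3, '*', 'x'), '-', ('y', '/', 2))) == 2\n",
    "```"
   ]
  },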
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Differentiation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In calculus, *differentiation* is the process of computing the *derivative* of a function: the rate of change of the function with respect to a variable. We denote the derivative of $y$ with respect to the variable $x$ as $\\frac{dy}{dx}$.  \n",
    "\n",
    "We can compute the derivative by consulting a [table of differentiation rules](http://myhandbook.info/form_diff.html).  Here are the rules for variables, constants, unary negation, and the four basic binary arithmetic operations: "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$\\frac{d}{dx} (x) = 1$\n",
    "\n",
    "$\\frac{d}{dx} (c) = 0$\n",
    "\n",
    "$\\frac{d}{dx} (-u) = - \\frac{du}{dx}$\n",
    "\n",
    "$\\frac{d}{dx} (u + v) = \\frac{du}{dx} + \\frac{dv}{dx}$\n",
    "\n",
    "$\\frac{d}{dx} (u - v) = \\frac{du}{dx} - \\frac{dv}{dx}$\n",
    "\n",
    "$\\frac{d}{dx} (u \\times v) = v \\frac{du}{dx} + u \\frac{dv}{dx}$\n",
    "\n",
    "$\\frac{d}{dx} (u \\; / \\; v) = (v \\frac{du}{dx} - u \\frac{dv}{dx}) \\; / \\; v^2$"
   ]
  },
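  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These rules can be spot-checked numerically before we encode them. The sketch below is an illustration, not part of the program: the functions `u` and `v` are arbitrary choices made for the example, and `numeric_derivative` is a hypothetical helper that approximates a derivative with a central finite difference.\n",
    "\n",
    "```python\n",
    "def numeric_derivative(f, x, h=1e-6):\n",
    "    'Central finite-difference approximation of df/dx at x.'\n",
    "    return (f(x + h) - f(x - h)) / (2 * h)\n",
    "\n",
    "# Arbitrary example functions and their hand-computed derivatives.\n",
    "u  = lambda x: x * x + 1\n",
    "v  = lambda x: 3 * x - 2\n",
    "du = lambda x: 2 * x\n",
    "dv = lambda x: 3\n",
    "\n",
    "x0 = 1.5\n",
    "product_rule  = v(x0) * du(x0) + u(x0) * dv(x0)\n",
    "quotient_rule = (v(x0) * du(x0) - u(x0) * dv(x0)) / v(x0) ** 2\n",
    "\n",
    "# The rules should agree with the numeric slopes up to approximation error.\n",
    "assert abs(numeric_derivative(lambda x: u(x) * v(x), x0) - product_rule) < 1e-6\n",
    "assert abs(numeric_derivative(lambda x: u(x) / v(x), x0) - quotient_rule) < 1e-6\n",
    "```"
   ]
  },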
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Defining `D(y, x)`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will define the Python function `D(y, x)` to compute the symbolic derivative  of the expression `y` with respect to the variable `x`. We can handle each of the seven rules, one by one.  If the input is of a form that does not match one of the rules, we will raise a `ValueError`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def D(y, x='x'):\n",
    "    \"\"\"Return the symbolic derivative, dy/dx.\n",
    "    Handles binary operators +, -, *, /, and unary -, over variables and constants.\"\"\"\n",
    "    if y == x:            \n",
    "        return 1        # d/dx (x) = 1\n",
    "    if arity(y) == 0:\n",
    "        return 0        # d/dx (c) = 0\n",
    "    if arity(y) == 1:\n",
    "        (op, u) = y     # y is a compound expression of the form (op, u)\n",
    "        if op == '-':   return ('-', D(u, x)) # d/dx (-u) = - du/dx  \n",
    "    if arity(y) == 2: \n",
    "        (u, op, v) = y  # y is a compound expression of the form (u, op, v)\n",
    "        if op == '+':   return D(u, x), '+', D(v, x)\n",
    "        if op == '-':   return D(u, x), '-', D(v, x)\n",
    "        if op == '*':   return (v, '*', D(u, x)), '+',  (u, '*', D(v, x))\n",
    "        if op == '/':   return ((v, '*', D(u, x)), '-',  (u, '*', D(v, x))), '/', (v, '*', v)\n",
    "    raise ValueError(\"D can't handle this: \" + str(y))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see what `D` can do:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, '+', 0)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D(('x', '+', 1), 'x')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((('x', '*', 0), '+', (3, '*', 1)), '+', 0)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D(((3, '*', 'x'), '+', 'c'), 'x')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(('y', '*', 1), '+', ('y', '*', 1))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D(('y', '*', 'y'), 'y')"
   ]
  },
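  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to convince ourselves that an unsimplified result is still right is to evaluate it numerically. The sketch below assumes a hypothetical helper, `evaluate`, that is not part of this notebook's program: it computes the value of a tuple expression given a dictionary of variable bindings (the bindings in `env` are arbitrary).\n",
    "\n",
    "```python\n",
    "import operator\n",
    "\n",
    "OPS = {'+': operator.add, '-': operator.sub,\n",
    "       '*': operator.mul, '/': operator.truediv}\n",
    "\n",
    "def evaluate(expr, env):\n",
    "    'Hypothetical helper: numeric value of a tuple expression under env.'\n",
    "    if isinstance(expr, str):          # variable or named constant\n",
    "        return env[expr]\n",
    "    if not isinstance(expr, tuple):    # plain number\n",
    "        return expr\n",
    "    if len(expr) == 2:                 # unary form (op, u); only negation here\n",
    "        op, u = expr\n",
    "        if op != '-':\n",
    "            raise ValueError('unknown unary operator: ' + str(op))\n",
    "        return -evaluate(u, env)\n",
    "    u, op, v = expr                    # binary form (u, op, v)\n",
    "    return OPS[op](evaluate(u, env), evaluate(v, env))\n",
    "\n",
    "env = {'x': 7, 'y': 5, 'c': 2}\n",
    "assert evaluate((1, '+', 0), env) == 1\n",
    "assert evaluate(((('x', '*', 0), '+', (3, '*', 1)), '+', 0), env) == 3\n",
    "assert evaluate((('y', '*', 1), '+', ('y', '*', 1)), env) == 2 * env['y']\n",
    "```"
   ]
  },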
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These answers are all mathematically *correct*, but they are not *simplified*&mdash;better answers would be `1`, `3`, and `(2, '*', 'y')`, respectively. So we have solved the basic problem, but there are enhancements we could make in several directions:\n",
    "\n",
    "* *Simplify*: Return `1` instead of `(1, '+', 0)`.\n",
    "* *Diversify*: Handle more types of expressions, like $x^2$, and $\\sin(x)$.\n",
    "* *Refactor*: Make the code cleaner by introducing a better representation for expressions.\n",
    "\n",
    "I choose to refactor first (because that will make the other two tasks easier). \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Representing Symbolic Expressions as Class Objects"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I will refactor the representation of expressions, replacing tuples with a user-defined class, `Expression`.  Why do that? Three reasons:\n",
    "\n",
    "* **Hygiene**: The code is cleaner with a separate type. A tuple can be used for a lot of things; when I have an expression I want to know for sure it is an expression.  \n",
    "* **Ease of Coding**: I will be able to form a new expression with the simple code `x + 1` rather than the more verbose `(x, '+', 1)`.\n",
    "* **Prettier Output**: We can see the prettier `(x + 1)` form on output, rather than the uglier `('x', '+', 1)`.\n",
    "\n",
    "Here's what I have in mind:\n",
    "\n",
    "\n",
    "<table>\n",
    "<tr><th>Math<br>Expression<th>Math<br>Type<th>Arity<th>Python<br>Representation<th>Python<br>Type\n",
    "<tr><td>$3$<td>Integer<td>0<td><tt>3<td><tt>int\n",
    "<tr><td>$3.14$<td>Number<td>0<td><tt>3.14<td><tt>float\n",
    "<tr><td>$x$<td>Variable<td>0<td><tt>Expression('x')<td><tt>Expression\n",
    "<tr><td>$\\pi$<td>Constant<td>0<td><tt>Expression('pi')<td><tt>Expression\n",
    "<tr><td>$-x$<td>Unary op<td>1<td><tt>Expression('-', x)</tt> (given def for `x`)<td><tt>Expression\n",
    "<tr><td>$\\sin(x)$<td>Unary func<td>1<td><tt>Expression('sin', x)</tt> (given def for `x`)<td><tt>Expression\n",
    "<tr><td>$x+1$<td>Binary op<td>2<td><tt>Expression('+', x, 1)</tt> (given def for `x`)<td><tt>Expression\n",
    "<tr><td>$3x - y/2$<td>Nested<br>Binary op<td>2<td><tt>3 * x - y / 2</tt>&nbsp;&nbsp;&nbsp;(given defs for `x`, `y`)<td><tt>Expression\n",
    "</table>\n",
    "\n",
    "Note that not every math expression is an instance of the new class `Expression`; we still use good old Python `int` and `float` for numbers. But variables, constants (like $\\pi$ or $c$), and unary and binary expressions are all instances of `Expression`.  Note also that I considered having separate subclasses of `Expression` for each of these, but I decided the complication was not worth it.\n",
    "\n",
    "To understand the implementation of `Expression`, you need to understand Python's [special method names](https://docs.python.org/3/reference/datamodel.html#special-method-names), the method names that begin and end with double underscores, like `__init__` and `__add__`. If you're not familiar with the concept, glance at [the documentation](https://docs.python.org/3/reference/datamodel.html#special-method-names), but here is a quick summary:\n",
    "\n",
    "* `__init__(self, ...)` *The constructor that defines what happens when a new object is created.*\n",
    "* `__add__(self, other)` *The method called when the code asks for `self + other`. Similar for `__sub__`, etc.*\n",
    "* `__radd__(self, other)` *'r' for reverse. Called when the code asks for, say, `1 + self`, where `1` is not an `Expression`.  Similar for `__rsub__`, etc.*\n",
    "* `__neg__(self)` *The method called when the code says `(- self)`.  Similar for `__pos__`.*\n",
    "* `__hash__(self)` *Computes a hash value, an integer, used as a key into a hash table (`dict`).*\n",
    "* `__eq__(self, other)` *Called to check whether `self == other`.*"
   ]
  },
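  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the dispatch rules in the summary above concrete, here is a minimal sketch using a hypothetical toy class, `Tag`, that is not part of the program. Each special method simply returns a tuple recording which method Python called and with which operands.\n",
    "\n",
    "```python\n",
    "class Tag:\n",
    "    'Hypothetical toy class that records which special method was called.'\n",
    "    def __init__(self, name):  self.name = name\n",
    "    def __add__(self, other):  return ('__add__', self.name, other)\n",
    "    def __radd__(self, other): return ('__radd__', other, self.name)\n",
    "    def __neg__(self):         return ('__neg__', self.name)\n",
    "\n",
    "t = Tag('t')\n",
    "assert t + 1 == ('__add__', 't', 1)    # Tag is on the left: __add__\n",
    "assert 1 + t == ('__radd__', 1, 't')   # int cannot add a Tag: __radd__\n",
    "assert -t == ('__neg__', 't')          # unary minus: __neg__\n",
    "```"
   ]
  },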
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "\n",
    "class Expression(object): \n",
    "    \"A mathematical expression (other than a number).\"\n",
    "    def __init__(self, op, *args): self.op, self.args = op, args\n",
    "        \n",
    "    def __add__(self, other):  return Expression('+',  self, other)\n",
    "    def __sub__(self, other):  return Expression('-',  self, other)\n",
    "    def __mul__(self, other):  return Expression('*',  self, other)\n",
    "    def __truediv__(self, other): return Expression('/',  self, other)\n",
    "    def __pow__(self, other):  return Expression('**', self, other)\n",
    "    \n",
    "    def __radd__(self, other): return Expression('+',  other, self)\n",
    "    def __rsub__(self, other): return Expression('-',  other, self)\n",
    "    def __rmul__(self, other): return Expression('*',  other, self)\n",
    "    def __rtruediv__(self, other): return Expression('/',  other, self)\n",
    "    def __rpow__(self, other): return Expression('**', other, self)\n",
    "    \n",
    "    def __neg__(self):         return Expression('-', self)\n",
    "    def __pos__(self):         return self\n",
    "    \n",
    "    def __hash__(self):        return hash(self.op) ^ hash(self.args)\n",
    "    def __ne__(self, other):   return not (self == other)\n",
    "    def __eq__(self, other):   return (isinstance(other, Expression) \n",
    "                                       and self.op == other.op \n",
    "                                       and self.args == other.args)\n",
    "    def __repr__(self):\n",
    "        \"A string representation of the Expression.\"\n",
    "        op, args = self.op, self.args\n",
    "        n = arity(self)\n",
    "        if n == 0: \n",
    "            return op\n",
    "        if n == 2: \n",
    "            return '({} {} {})'.format(args[0], op, args[1])\n",
    "        if n == 1:\n",
    "            arg = str(args[0])\n",
    "            if arg.startswith('(') or op in {'+', '-'}:\n",
    "                return '{}{}'.format(op, arg)\n",
    "            else:\n",
    "                return '{}({})'.format(op, arg)\n",
    "\n",
    "def arity(expr):\n",
    "    \"The number of sub-expressions in this expression.\"\n",
    "    if isinstance(expr, Expression):\n",
    "        return len(expr.args)\n",
    "    else:\n",
    "        return 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's first define some symbols:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "a,    b,   c,   d,   u,   v,   w,   x,   y,   z,   pi ,  e = map(Expression, [\n",
    "'a', 'b', 'c', 'd', 'u', 'v', 'w', 'x', 'y', 'z', 'pi', 'e'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So, for example, `a` is now an object of type `Expression` such that `a.op` is the string `'a'` and `a.args` is the empty tuple, and `arity(a)` is 0. The Python expression `a + 1` is equivalent to `a.__add__(1)`, which invokes the `Expression.__add__` method to create  a new object of type `Expression`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(a + 1)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a + 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Python expression `1 + a` first tries `(1).__add__(a)`.  But since the integer `1` doesn't know how to add an `Expression`, that call returns `NotImplemented`, and Python instead calls `a.__radd__(1)`, where `radd` stands for \"reversed add\":"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1 + a)"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "1 + a"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try something more interesting:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((-b + ((b ** 2) - ((4 * a) * c))) / (2 * a))"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(-b + (b**2 - 4*a*c)) / (2 * a)"
   ]
  },
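  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `__eq__` and `__hash__` methods matter because they let two separately built expressions compare equal and serve as dictionary keys. In Python 3, a class that defines `__eq__` without `__hash__` is unhashable. The sketch below uses a hypothetical trimmed-down stand-in, `Expr`, so that it runs on its own; it mirrors only the relevant methods of `Expression`.\n",
    "\n",
    "```python\n",
    "class Expr:\n",
    "    'Hypothetical trimmed-down stand-in for Expression.'\n",
    "    def __init__(self, op, *args): self.op, self.args = op, args\n",
    "    def __add__(self, other):      return Expr('+', self, other)\n",
    "    def __hash__(self):            return hash(self.op) ^ hash(self.args)\n",
    "    def __eq__(self, other):\n",
    "        return (isinstance(other, Expr)\n",
    "                and self.op == other.op\n",
    "                and self.args == other.args)\n",
    "\n",
    "x = Expr('x')\n",
    "assert x + 1 == x + 1      # equal structure, although two distinct objects\n",
    "assert x + 1 != x + 2      # different arguments\n",
    "assert x != 'x'            # an Expr never equals a plain string\n",
    "\n",
    "# Structural hashing lets a freshly built expression find an existing key.\n",
    "table = {x + 1: 'found'}\n",
    "assert table[Expr('+', Expr('x'), 1)] == 'found'\n",
    "```"
   ]
  },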
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rewriting `D(y, x)` with the Expression Class\n",
    "===\n",
    "\n",
    "We can now rewrite `D`. Notice that the code is a bit neater: we don't have to create tuples, and can just use operators like `+`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def D(y, x=x):\n",
    "    \"\"\"Return the symbolic derivative, dy/dx.\n",
    "    Handles binary operators +, -, *, /, and unary -, over variables and constants.\"\"\"\n",
    "    if y == x:            \n",
    "        return 1        # d/dx (x) = 1\n",
    "    if arity(y) == 0:\n",
    "        return 0        # d/dx (c) = 0\n",
    "    op = y.op\n",
    "    if arity(y) == 1:\n",
    "        u = y.args[0]\n",
    "        if op == '-':   return -D(u, x) # d/dx (-u) = - du/dx\n",
    "    if arity(y) == 2: \n",
    "        (u, v) = y.args\n",
    "        if op == '+':   return D(u, x) + D(v, x)\n",
    "        if op == '-':   return D(u, x) - D(v, x)\n",
    "        if op == '*':   return D(u, x) * v +  D(v, x) * u\n",
    "        if op == '/':   return (v * D(u, x) - u * D(v, x)) / v ** 2\n",
    "    raise ValueError(\"D can't handle this: \" + str(y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D(x + 2, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(((0 * x) + 3) + 0)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D(3 * x + c, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1 * y) + (1 * y))"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D(y * y, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((0 * x) + (1 * -c))"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D(- c * x, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Representing Functions\n",
    "===\n",
    "\n",
    "Let's represent expressions like $\\sin(x)$. We'll need to be clear about the difference between the function `sin` and the functional expression `sin(x)`.  I choose to do that by making `sin` be an object of type `Function`, which is a subclass of `Expression`, and making `sin(x)` be an `Expression` whose `op` is `sin`:\n",
    "\n",
    "<table>\n",
    "<tr><th>Expression<th>Type<th>Arity<th>op<th>args\n",
    "<tr><td>`sin`<td>`Function`<td>0<td>`'sin'`<td>`()`\n",
    "<tr><td>`sin(x)`<td>`Expression`<td>1<td>`sin`<td>`(x,)`\n",
    "</table>\n",
    " \n",
    "\n",
    "What operations does `Function` need to support?  We need to be able to create a new one, with an op and no args; it can inherit the `__init__` method from `Expression` to do this.  Once we create a new one, it needs to be able to create a functional expression, like `sin(x)`.  We can do that with the `__call__` method of `Function`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class Function(Expression):\n",
    "    \"A mathematical function of one argument, like sin or ln.\"\n",
    "    def __call__(self, x): return Expression(self, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "sin,    cos,   tan,   cot,   sec,   csc,   ln,   sqrt = map(Function, [\n",
    "'sin', 'cos', 'tan', 'cot', 'sec', 'csc', 'ln', 'sqrt'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "sin"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'args': (), 'op': 'sin'}"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vars(sin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "sin(x)"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sin(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'args': (x,), 'op': sin}"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vars(sin(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((-b + sqrt((b ** 2) - ((4 * a) * c))) / (2 * a))"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(-b + sqrt(b**2 - 4*a*c)) / (2 * a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((sin(x) ** 2) + (cos(x) ** 2))"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sin(x)**2 + cos(x)**2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "Extending `D(y, x)` to handle Exponentiation and Functions\n",
    "====\n",
    "\n",
    "In this section, we extend `D(y, x)` to handle:\n",
    "\n",
    "* The binary exponentiation operator, $u^v$, denoted with `u ** v`.\n",
    "* The six trigonometric functions sin, cos, tan, cot, csc, sec, and the natural logarithm function, ln.\n",
    "* The chain rule for nested functions.\n",
    "\n",
    "Here are the rules:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$\\frac{d}{dx} (x^n) = n x^{(n-1)}$   *if n is an integer*\n",
    "\n",
    "$\\frac{d}{dx} (u^v) = v u^{(v-1)} \\times \\frac{du}{dx} + u^v \\times \\ln(u) \\times \\frac{dv}{dx}$\n",
    "\n",
    "$\\frac{d}{dx} \\sin(x) = \\cos(x)$\n",
    "\n",
    "$\\frac{d}{dx} \\cos(x) = -\\sin(x)$\n",
    "\n",
    "$\\frac{d}{dx} \\tan(x) = \\sec(x)^2$\n",
    "\n",
    "$\\frac{d}{dx} \\cot(x) = -\\csc(x)^2$\n",
    "\n",
    "$\\frac{d}{dx} \\sec(x) = \\sec(x) \\tan(x)$\n",
    "\n",
    "$\\frac{d}{dx} \\csc(x) = - \\csc(x) \\cot(x)$\n",
    "\n",
    "$\\frac{d}{dx} \\ln(x) = 1 / x$\n",
    "\n",
    "$\\frac{dy}{dx} = \\frac{dy}{du} \\frac{du}{dx}$\n",
    "\n",
    "We can add these rules to the ones we had before, giving us:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def D(y, x=x):\n",
    "    \"\"\"Return the symbolic derivative, dy/dx.\n",
    "    Handles binary operators +, -, *, /, **, and unary +, -, as well as\n",
    "    trig functions and ln, using the chain rule for nested arguments.\"\"\"\n",
    "    if y == x:            \n",
    "        return 1        # d/dx (x) = 1\n",
    "    if arity(y) == 0:\n",
    "        return 0        # d/dx (c) = 0\n",
    "    if arity(y) == 1:\n",
    "        op, (u,) = y.op, y.args\n",
    "        if op == '+':  return D(u, x)\n",
    "        if op == '-':  return -D(u, x)\n",
    "        if u  != x:    return D(y, u) * D(u, x) ## CHAIN RULE\n",
    "        if op == sin:  return cos(x)\n",
    "        if op == cos:  return -sin(x)\n",
    "        if op == tan:  return sec(x) ** 2\n",
    "        if op == cot:  return -csc(x) ** 2\n",
    "        if op == sec:  return sec(x) * tan(x)\n",
    "        if op == csc:  return -csc(x) * cot(x)\n",
    "        if op == ln:   return 1 / x  \n",
    "    if arity(y) == 2: \n",
    "        op, (u, v) = y.op, y.args\n",
    "        if op == '+':  return D(u, x) + D(v, x)\n",
    "        if op == '-':  return D(u, x) - D(v, x)\n",
    "        if op == '*':  return D(u, x) * v +  D(v, x) * u\n",
    "        if op == '/':  return (v * D(u, x) - u * D(v, x)) / v ** 2\n",
    "        if op == '**': return (v * u**(v-1) if u == x and isinstance(v, int) else\n",
    "                               v * u**(v-1) * D(u, x) + u**v * ln(u) * D(v, x))\n",
    "    raise ValueError(\"D can't handle this: \" + str(y))"
   ]
  },
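  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One detail in the rules above depends on Python's operator precedence: `**` binds more tightly than unary minus, which in turn binds more tightly than `*`. So `-csc(x) ** 2` means $-(\\csc(x)^2)$, as the cot rule requires, while `-csc(x) * cot(x)` negates only the first factor. The sketch below illustrates the grouping with a hypothetical recording class, `Sym`, that is not part of the program.\n",
    "\n",
    "```python\n",
    "class Sym:\n",
    "    'Hypothetical stand-in that records how Python groups operators.'\n",
    "    def __init__(self, text):  self.text = text\n",
    "    def __neg__(self):         return Sym('-' + self.text)\n",
    "    def __pow__(self, n):      return Sym('(' + self.text + ' ** ' + str(n) + ')')\n",
    "    def __mul__(self, other):  return Sym('(' + self.text + ' * ' + other.text + ')')\n",
    "\n",
    "csc_x, cot_x = Sym('csc(x)'), Sym('cot(x)')\n",
    "\n",
    "# The power is applied first, then the negation.\n",
    "assert (-csc_x ** 2).text == '-(csc(x) ** 2)'\n",
    "\n",
    "# The negation is applied first, then the product.\n",
    "assert (-csc_x * cot_x).text == '(-csc(x) * cot(x))'\n",
    "\n",
    "# The same grouping holds for plain numbers.\n",
    "assert -2 ** 2 == -4 and (-2) ** 2 == 4\n",
    "```"
   ]
  },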
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Chain Rule\n",
    "---\n",
    "\n",
    "The trickiest part is the chain rule.  Consider $\\frac{d}{dx} \\sin(\\ln(x))$.  Let:\n",
    "\n",
    "$y = \\sin(\\ln(x))$\n",
    "\n",
    "$u = \\ln(x)$\n",
    "\n",
    "According to the chain rule, \n",
    "\n",
    "$\\frac{dy}{dx} = \\frac{dy}{du} \\frac{du}{dx}$.  \n",
    "\n",
    "So we have:\n",
    "\n",
    "$\\frac{dy}{du} = \\frac{d}{du} \\sin(u) = \\cos(u) = \\cos(\\ln(x))$\n",
    "\n",
    "$\\frac{du}{dx} = \\frac{d}{dx} \\ln(x) = \\frac{1}{x}$\n",
    "\n",
    "Finally, we can multiply them together to get:\n",
    "\n",
    "$\\frac{dy}{dx} = \\frac{dy}{du} \\;\\frac{du}{dx} = \\cos(\\ln(x)) \\; \\frac{1}{x} = \\cos(\\ln(x)) / x$.\n",
    "\n",
    "\n",
    "Now let's do the same thing in Python. Let:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "Y = sin(ln(x))\n",
    "U = ln(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So let's go on with the computation of the derivative, `D(y, x) = D(y, u) * D(u, x)`. We'll look at the two terms on the right:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "cos(ln(x))"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D(Y, U)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1 / x)"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D(U, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And multiply the two terms together:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(cos(ln(x)) * (1 / x))"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D(Y, U) * D(U, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can do the same computation all in one step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(cos(ln(x)) * (1 / x))"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D(Y, x)"
   ]
  },
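  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the chain-rule result, we can compare it against a numeric slope. This sketch is independent of `D`: it transcribes the symbolic answer `cos(ln(x)) * (1 / x)` using the standard `math` module (with `math.log` standing in for `ln`), and uses a hypothetical helper, `numeric_derivative`, based on a central finite difference. The sample points are arbitrary positive values.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def numeric_derivative(f, x, h=1e-6):\n",
    "    'Central finite-difference approximation of df/dx at x.'\n",
    "    return (f(x + h) - f(x - h)) / (2 * h)\n",
    "\n",
    "def y(x):\n",
    "    return math.sin(math.log(x))\n",
    "\n",
    "def dy_dx(x):\n",
    "    'The symbolic result, transcribed as ordinary arithmetic.'\n",
    "    return math.cos(math.log(x)) * (1 / x)\n",
    "\n",
    "for x0 in (0.5, 1.0, 2.0, 10.0):\n",
    "    assert abs(numeric_derivative(y, x0) - dy_dx(x0)) < 1e-6\n",
    "```"
   ]
  },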
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "More Examples of Differentiation\n",
    "===\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First an example where all is well:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3 * (x ** 2))"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D(x ** 3, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But most of the time, `D` gives an answer that is technically correct, but not in simplified form. The example below should yield `2*a*x + b`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((((0 * (x ** 2)) + ((2 * (x ** 1)) * a)) + ((0 * x) + (1 * b))) + 0)"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D(a*x**2 + b*x + c, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next example should give `50 * (5*x - 2)**9`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(((10 * (((5 * x) - 2) ** 9)) * (((0 * x) + 5) - 0)) + (((((5 * x) - 2) ** 10) * ln((5 * x) - 2)) * 0))"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D((5*x - 2)**10, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And for the next example, we should get $\\frac{d}{dx} \\sin(\\ln(x^2)) = 2 \\cos(\\ln(x^2)) / x$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(cos(ln(x ** 2)) * ((1 / (x ** 2)) * (2 * (x ** 1))))"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D(sin(ln(x**2)), x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In each case, the answer is correct, but not simplified."
   ]
  },
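  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To back up the claim that these unsimplified answers are correct, the sketch below transcribes two of the outputs above as ordinary Python arithmetic and compares them with the expected simplified forms. The function names and the sample values for `a`, `b`, `c`, and `x` are assumptions made for this illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def d_quadratic(a, b, c, x):\n",
    "    'Unsimplified output of D(a*x**2 + b*x + c, x), as plain arithmetic.'\n",
    "    return ((((0 * (x ** 2)) + ((2 * (x ** 1)) * a)) + ((0 * x) + (1 * b))) + 0)\n",
    "\n",
    "def d_power(x):\n",
    "    'Unsimplified output of D((5*x - 2)**10, x); math.log stands in for ln.'\n",
    "    return (((10 * (((5 * x) - 2) ** 9)) * (((0 * x) + 5) - 0))\n",
    "            + (((((5 * x) - 2) ** 10) * math.log((5 * x) - 2)) * 0))\n",
    "\n",
    "for x in (1, 2, 3):\n",
    "    assert d_quadratic(4, 7, 9, x) == 2 * 4 * x + 7\n",
    "    assert d_power(x) == 50 * (5 * x - 2) ** 9\n",
    "```\n",
    "\n",
    "The sample points keep `5*x - 2` positive, because the numeric `math.log` must be evaluated even though its term is then multiplied by 0."
   ]
  },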
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Simplification: `simp(y)`\n",
    "------\n",
    "\n",
    "\n",
    "Now let's work on simplification.  We'll define `simp(y)` to return an expression that is equivalent to the expression $y$, but simpler. We should say that the notion of *simpler* is not at all precise.  Which is simpler, $x^2 -x -12$ or $(x + 3)(x - 4)$?  It depends on what you want to do next.  But we can all agree that $x$ is simpler than $(2 - 1)x + 0$.  Our version of `simp` will just do the easiest of simplifications.  There's lots we could add to make it better.\n",
    "\n",
    "Another difficult issue is dealing with indeterminate results like dividing by 0.  We would like to be able to simplify $x/x$ to 1.  But that is only true if $x$ is not 0.  We'll ignore that problem and go ahead and simplify $x/x$ to 1, but we'll feel a little uneasy about it. We'll also simplify 0/0 to `undefined` and 1/0 to `infinity`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def simp(y):\n",
    "    \"Simplify an expression.\"\n",
    "    if arity(y) == 0:\n",
    "        return y \n",
    "    y = Expression(y.op, *map(simp, y.args))  ## Simplify the sub-expressions first\n",
    "    op = y.op\n",
    "    if arity(y) == 1:\n",
    "        (u,) = y.args\n",
    "        if y in simp_table:\n",
    "            return simp_table[y]\n",
    "        if op == '+':\n",
    "            return u                         # + u = u\n",
    "        if op == '-':\n",
    "            if arity(u) == 1 and u.op == '-':\n",
    "                return u.args[0]             # - - w = w\n",
    "    if arity(y) == 2:\n",
    "        (u, v) = y.args\n",
    "        if evaluable(u, op, v): # for example, (3 + 4) simplifies to 7\n",
    "            return eval(str(u) + op + str(v), {})\n",
    "        if op == '+':\n",
    "            if u == 0: return v              # 0 + v = v\n",
    "            if v == 0: return u              # u + 0 = u\n",
    "            if u == v: return 2 * u          # u + u = 2 * u\n",
    "        if op == '-':\n",
    "            if u == v: return 0              # u - v = 0\n",
    "            if v == 0: return u              # u - 0 = u\n",
    "            if u == 0: return -v             # 0 - v = -v\n",
    "        if op == '*':\n",
    "            if u == 0 or v == 0: return 0    # 0 * v = u * 0 = 0\n",
    "            if u == 1: return v              # 1 * v = v\n",
    "            if v == 1: return u              # u * 1 = u\n",
    "            if u == v: return u ** 2         # u * u = u^2\n",
    "        if op == '/':\n",
    "            if u == 0 and v == 0: return undefined # 0 / 0 = undefined\n",
    "            if u == v: return 1              # u / u = 1\n",
    "            if v == 1: return u              # u / 1 = u\n",
    "            if u == 0: return 0              # 0 / v = 0\n",
    "            if v == 0: return infinity       # u / 0 = infinity\n",
    "        if op == '**':\n",
    "            if v == 1: return u              # u ** 1 = u\n",
    "            if u == v == 0: return undefined # 0 ** 0 = undefined\n",
    "            if v == 0: return 1              # u ** 0 = 1  \n",
    "\n",
    "    # If no rules apply, return y unchanged.\n",
    "    return y\n",
    "\n",
    "from numbers import Number\n",
    "\n",
    "# Deal with infinity and with undefined numbers (like infinity minus infinity)\n",
    "infinity = float('inf') \n",
    "undefined = nan = (infinity - infinity)\n",
    "\n",
    "# Table of known exact values for certain functions.\n",
    "# Use this, for example, to simplify sin(pi) to 0 or ln(e) to 1.\n",
    "\n",
    "simp_table = {\n",
    "    sin(0): 0, sin(pi): 0,  sin(pi/2): 1,        sin(2*pi): 0,\n",
    "    cos(0): 1, cos(pi): -1, cos(pi/2): 1,        cos(2*pi): 1,\n",
    "    tan(0): 0, tan(pi): 0,  tan(pi/2): infinity, tan(2*pi): 0, tan(pi/4): 1,\n",
    "    ln(1): 0,  ln(e): 1,    ln(0): -infinity}\n",
    "\n",
    "def evaluable(u, op, v):\n",
    "    \"Can we evaluate (u op v) to a number? True if u and v are numbers, and not a special case like 0^0.\"\n",
    "    return (isinstance(u, Number) and isinstance(v, Number)\n",
    "            and not (op == '/' and (v == 0 or u % v != 0))\n",
    "            and not (op == '^' and (u == v == 0)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try it out.  First an easy one; $1x + 0$ should be $x$: Simplifying "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "x"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "simp(x*(y/y) + 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly, $(2 - 1)x^{\\cos(y-y)} + \\tan(\\pi)$ should be just $x$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "x"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "simp((2 - 1) * x ** cos(y-y) + tan(pi))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here are two examples where `simp` helps, but does not do all the simplifications it could:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(cos(ln(x ** 2)) * ((1 / (x ** 2)) * (2 * x)))"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "simp(D(sin(ln(x**2)), x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((10 * (((5 * x) - 2) ** 9)) * 5)"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "simp(D((5 * x - 2)**10, x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "simp(cos(pi - pi))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "simp(D(3 * x + c, x))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
